# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DCN module"""
import mindspore
from mindspore import nn
from mindspore import ops
from mindspore.ops import operations as P


class BatchMapCoordinates(nn.Cell):
    """Bilinearly sample a batch of 2D feature maps at arbitrary coordinates.

    A batched ``map_coordinates`` that only supports 2D feature maps and
    first-order (bilinear) interpolation.

    Args:
        order (int): Stored as ``self.order`` but not used; interpolation is
            always bilinear. Default: 1.

    Inputs:
        - **x** (Tensor) - Feature maps of shape (b, h, w).
        - **coords** (Tensor) - Float sampling coordinates of shape
          (b, n_points, 2). ``coords[..., 0]`` indexes the h axis and
          ``coords[..., 1]`` indexes the w axis. Coordinates outside the map are
          clamped to its border.

    Outputs:
        Tensor of shape (b, n_points): the interpolated value of ``x`` at every
        coordinate.
    """

    def __init__(self, order=1):
        super(BatchMapCoordinates, self).__init__()
        self.order = order
        self.cast = P.Cast()
        self.stack = P.Stack(axis=-1)
        self.gather_nd = P.GatherNd()
        self.reshape = P.Reshape()
        self.floor = ops.Floor()
        # Not used by construct (ops.Ceil is avoided there, see below); kept so
        # the cell's attributes are unchanged.
        self.ceil_op = ops.Ceil()

    def _clip_coords(self, coords, min_value, max_h, max_w):
        """Clamp coords[..., 0] to [min_value, max_h] and coords[..., 1] to [min_value, max_w]."""
        coords_h = ops.clip_by_value(coords[..., 0], min_value, max_h)
        coords_w = ops.clip_by_value(coords[..., 1], min_value, max_w)
        return self.stack([coords_h, coords_w])

    def _gather_corner(self, x, idx, corner, batch_size, n_coords):
        """Gather x[idx, corner[..., 0], corner[..., 1]] and reshape it to (batch_size, n_coords)."""
        indices = self.stack([
            idx,
            self.reshape(corner[..., 0], (-1,)),
            self.reshape(corner[..., 1], (-1,)),
        ])
        vals = self.gather_nd(x, indices)
        return self.reshape(vals, (batch_size, n_coords))

    def construct(self, x, coords):
        """Sample ``x`` at ``coords`` with bilinear interpolation."""
        input_shape = x.shape
        batch_size = input_shape[0]
        height = input_shape[1]
        width = input_shape[2]
        n_coords = coords.shape[1]

        # Clamp each axis with its own bound so non-square maps are handled.
        min_value = self.cast(0, mindspore.float32)
        max_h = self.cast(height - 1, mindspore.float32)
        max_w = self.cast(width - 1, mindspore.float32)
        coords = self._clip_coords(coords, min_value, max_h, max_w)

        # Top-left corner: floor of the clamped coordinates.
        coords_lt = self.cast(self.floor(coords), mindspore.int32)
        # Bottom-right corner. ops.Ceil is not supported on GPU, so
        # floor(coords + 1) is used instead. Unlike ceil, it yields h (or w)
        # when a coordinate lies exactly on the last row/column, which is an
        # out-of-range GatherNd index. Clamp it back; that corner's bilinear
        # weight is 0 in that case, so the interpolated value is unaffected.
        coords_rb = self._clip_coords(self.floor(coords + 1), min_value, max_h, max_w)
        coords_rb = self.cast(coords_rb, mindspore.int32)
        coords_lb = self.stack([coords_lt[..., 0], coords_rb[..., 1]])
        coords_rt = self.stack([coords_rb[..., 0], coords_lt[..., 1]])

        # Batch index of every flattened point: 0 repeated n_coords times, then 1, ...
        out = nn.Range(batch_size)
        idx_in = out()
        idx = ops.repeat_elements(idx_in, n_coords)

        # Pixel values at the four neighbouring corners, each (batch_size, n_coords).
        vals_lt = self._gather_corner(x, idx, coords_lt, batch_size, n_coords)
        vals_rb = self._gather_corner(x, idx, coords_rb, batch_size, n_coords)
        vals_lb = self._gather_corner(x, idx, coords_lb, batch_size, n_coords)
        vals_rt = self._gather_corner(x, idx, coords_rt, batch_size, n_coords)

        # Bilinear interpolation: first along axis 0 (lt->rt, lb->rb), then along
        # axis 1 between the two intermediate results.
        coords_offset_lt = coords - self.cast(coords_lt, mindspore.float32)
        vals_t = vals_lt + (vals_rt - vals_lt) * coords_offset_lt[..., 0]
        vals_b = vals_lb + (vals_rb - vals_lb) * coords_offset_lt[..., 0]
        mapped_vals = vals_t + (vals_b - vals_t) * coords_offset_lt[..., 1]
        return mapped_vals


class BatchMapOffsets(nn.Cell):
    """DCN sampling: read feature maps at a regular pixel grid shifted by per-pixel offsets.

    Args:
        order (int): Stored as ``self.order`` but not used. Default: 1.

    Inputs:
        - **x** (Tensor) - Feature maps of shape (bc, h, w). Presumably batch and
          channel dimensions are merged by the caller; TODO confirm.
        - **offsets** (Tensor) - Offsets of shape (bc, h, w, 2).
          ``offsets[..., 0]`` is added to the row (h) index and
          ``offsets[..., 1]`` to the column (w) index of each grid position.

    Outputs:
        Tensor of shape (bc, h * w): the sampled value for every grid position,
        in row-major order.
    """

    def __init__(self, order=1):
        super(BatchMapOffsets, self).__init__()
        self.order = order
        self.reshape = P.Reshape()
        self.stack = P.Stack(axis=-1)
        self.cast = P.Cast()
        self.meshgrid = P.Meshgrid(indexing="ij")
        self.ms_batch_map_coordinates = BatchMapCoordinates()
        self.expand_dims = P.ExpandDims()
        self.tile = P.Tile()

    def construct(self, x, offsets):
        """Add ``offsets`` to the (h, w) pixel grid and sample ``x`` at the result."""
        input_shape = x.shape
        batch_size = input_shape[0]
        height = input_shape[1]
        width = input_shape[2]
        multiples = (batch_size, 1, 1)

        offsets = self.reshape(offsets, (batch_size, -1, 2))

        # Integer pixel grid: grid[i, j] = (i, j), built from the real h and w so
        # non-square maps are supported, then flattened row-major to (h * w, 2).
        row_range = nn.Range(height)
        col_range = nn.Range(width)
        rows = row_range()
        cols = col_range()
        grid = self.meshgrid((rows, cols))
        grid = self.stack(grid)
        grid = self.cast(grid, mindspore.float32)
        grid = self.reshape(grid, (-1, 2))
        # Replicate the grid for every map in the batch: (bc, h * w, 2).
        grid = self.expand_dims(grid, 0)
        grid = self.tile(grid, multiples)
        # Shift every grid position by its own offset.
        coords = offsets + grid

        mapped_vals = self.ms_batch_map_coordinates(x, coords)
        return mapped_vals
